"""
Phase 2: Baseline Regressions — Real Exchange Rate
=====================================================
Tests whether demographic structure predicts real effective exchange
rates across 104 countries — first global test extending
Groneck & Kaufmann's OECD-only findings.

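Key models:
  (a) Z → log(REER) with Balassa-Samuelson controls
  (b) Z → Δlog(REER)
  (c) OECD vs non-OECD vs full
  (d) Age ratios (old_dep, youth_dep) → REER
  (e) Income-tercile and eurozone / non-eurozone subsamples
  (f) Z → PPP deviation (country-demeaned log REER)
  (g) 5-year-lagged, first-differenced and predetermined (OADR+20) demographics

Outputs (in output/tables/): baseline_reer_level.md, baseline_reer_change.md,
reer_subsamples.md, reer_ppd.md, reer_lagged.md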
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd

# NOTE(review): hard-coded absolute project root — the script only runs on a
# machine with exactly this directory layout. Consider deriving it from __file__.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/rer")
ROOT_DIR = PROJECT_DIR.parent
# Make the shared PanelGLS estimator in <ROOT_DIR>/multilateral/src importable.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

PROCESSED_DIR = PROJECT_DIR / "data" / "processed"  # main() reads rer_panel.csv here
TABLES_DIR = PROJECT_DIR / "output" / "tables"  # build_table() writes markdown here

# ISO3 codes of the 38 OECD member countries.
# NOTE(review): main() does not reference this list — it splits the sample on
# the panel's own `oecd` column. Confirm the two agree, or drop one of them.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


def stars(p):
    """Map a p-value to conventional significance stars.

    Returns '***' below 0.01, '**' below 0.05, '*' below 0.10, and an empty
    string otherwise (including for NaN, which fails every comparison).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, *, min_obs=50):
    """Fit PanelGLS of *dep_var* on *regressors* and return a results dict.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel holding 'iso3' and 'year' columns plus the model variables.
    dep_var : str
        Dependent-variable column.
    regressors : list of str
        Requested regressor columns. Columns absent from *df* are dropped
        with a printed warning; the columns actually used are recorded under
        the ``'regressors'`` key of the result.
    label : str
        Model label used in console output and stored in the result.
    min_obs : int, keyword-only
        Skip the model when fewer complete-case rows remain (default 50).

    Returns
    -------
    dict or None
        Keys ``label``, ``dep_var``, ``regressors``, ``n_obs``,
        ``n_countries``, ``r_squared``, ``rho`` and, for every regressor used,
        ``coef_<name>``, ``se_<name>`` and ``p_<name>``. None when the model
        is skipped (missing dependent variable, no usable regressors, or too
        few observations).
    """
    if dep_var not in df.columns:
        print(f"  [{label}] {dep_var} missing — skipping")
        return None

    used = [r for r in regressors if r in df.columns]
    dropped = [r for r in regressors if r not in df.columns]
    if dropped:
        # Make a silently changed specification visible in the log.
        print(f"  [{label}] regressors not in panel, dropped: {', '.join(dropped)}")
    if not used:
        print(f"  [{label}] no regressors available — skipping")
        return None

    # Complete-case sample on the dependent variable and all regressors used.
    sub = df.dropna(subset=[dep_var] + used)
    if len(sub) < min_obs:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[used].values,
            sub['iso3'].values, sub['year'].values)

    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, "
          f"R²={gls.r_squared:.4f}")

    results = {
        'label': label, 'dep_var': dep_var, 'regressors': list(used),
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared, 'rho': gls.rho,
    }
    # gls.beta / se / pvalues are positional, in the column order of `used`.
    for i, name in enumerate(used):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<25} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")

    return results


def build_table(results, key_vars, controls_label, filename, title, *, out_dir=None):
    """Write a markdown summary of regression *results* to a file.

    The file has a fit-statistics table (one row per model), a key-coefficient
    table (one row per model × variable in *key_vars* that the model
    estimated), and *controls_label* as an italic footnote.

    Parameters
    ----------
    results : list of dict
        Result dicts as returned by run_model().
    key_vars : list of str
        Variables whose coefficients are reported when present in a result.
    controls_label : str
        Footnote text.
    filename : str
        Output file name inside *out_dir*.
    title : str
        Markdown H1 heading.
    out_dir : path-like, optional, keyword-only
        Destination directory (default TABLES_DIR). Created if missing.
    """
    if not results:
        print(f"\n  No results — {filename} not written")
        return

    md = [
        f"# {title}\n",
        "| Model | Dep Var | N | Countries | R² |",
        "|---|---|---|---|---|",
    ]
    md += [
        f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
        f"| {r['n_countries']} | {r['r_squared']:.3f} |"
        for r in results
    ]

    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")

    md.append(f"\n*{controls_label}*")

    target_dir = TABLES_DIR if out_dir is None else Path(out_dir)
    # Fresh checkouts have no output/ tree; write_text would raise otherwise.
    target_dir.mkdir(parents=True, exist_ok=True)
    out = target_dir / filename
    # Explicit UTF-8: the table text contains ², Δ and → characters.
    out.write_text('\n'.join(md) + '\n', encoding='utf-8')
    print(f"\n  Saved: {out}")


def _banner(title):
    """Print *title* between two 70-character '=' rules, after a blank line."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _subset(df, column, value):
    """Return a copy of the rows of *df* where *column* == *value*.

    If *column* is absent, print a notice and return an empty frame with the
    same columns. Downstream run_model() calls then skip with "Insufficient
    obs" instead of the whole script dying on a KeyError.
    """
    if column not in df.columns:
        print(f"  [subset] '{column}' not in panel — {column} == {value!r} subsample is empty")
        return df.iloc[0:0].copy()
    return df[df[column] == value].copy()


def _run_specs(specs):
    """Fit each ``(data, dep_var, regressors, label)`` spec in order.

    Returns the result dicts from run_model(), omitting the specs it skipped
    (for which it returns None).
    """
    results = []
    for data, dep_var, regressors, label in specs:
        res = run_model(data, dep_var, regressors, label)
        if res:
            results.append(res)
    return results


def main():
    """Run the Phase 2 REER regressions (Parts A-E) and write markdown tables.

    Reads ``rer_panel.csv`` from PROCESSED_DIR and writes
    baseline_reer_level.md, baseline_reer_change.md, reer_subsamples.md,
    reer_ppd.md and reer_lagged.md to TABLES_DIR.
    """
    print("=" * 70)
    print("PHASE 2: Baseline Regressions — Real Exchange Rate")
    print("=" * 70)

    # build_table() writes directly into TABLES_DIR; create it on a fresh checkout.
    TABLES_DIR.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(PROCESSED_DIR / "rer_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")
    print(f"REER coverage: {df['reer_combined'].notna().sum():,} obs, "
          f"{df.loc[df['reer_combined'].notna(), 'iso3'].nunique()} countries")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    age_vars = ['old_dep', 'youth_dep']
    lag_vars = ['Z_1_lag5', 'Z_2_lag5', 'Z_3_lag5']
    diff_vars = ['d_Z_1', 'd_Z_2', 'd_Z_3']

    # Standard controls (Paper 6 style)
    controls_basic = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']
    # Balassa-Samuelson controls (the key extension)
    controls_bs = controls_basic + ['log_gdp_pc']
    # Extended BS: add trade openness
    controls_bs_ext = controls_bs + ['trade_openness']
    # BS plus health expenditure as a non-tradable proxy
    controls_nontrad = controls_bs + ['health_exp_gdp']

    basic_str = ', '.join(controls_basic)
    bs_str = ', '.join(controls_bs)

    lhs_level = 'log_reer_combined'
    lhs_change = 'd_log_reer'
    lhs_ppd = 'log_reer_demeaned'

    # OECD split reused by Parts A, B and D; rows with a missing flag are in neither.
    oecd = _subset(df, 'oecd', 1)
    non_oecd = _subset(df, 'oecd', 0)

    # ═══════════════════════════════════════════════════════════════════
    # PART A: Level models — Z → log(REER)
    # ═══════════════════════════════════════════════════════════════════
    _banner("PART A: LEVEL MODELS — Z → log(REER)")
    level_results = _run_specs([
        # A1 replicates Paper 6 on the wider sample (no BS control)
        (df, lhs_level, demo_vars + controls_basic, "A1: Z (no BS control)"),
        (df, lhs_level, demo_vars + controls_bs, "A2: Z + BS (log GDP/pc)"),
        (df, lhs_level, demo_vars + controls_bs_ext, "A3: Z + BS + trade"),
        (df, lhs_level, demo_vars + controls_nontrad, "A4: Z + BS + health"),
        (df, lhs_level, age_vars + controls_bs, "A5: age ratios + BS"),
        # A6 is comparable to Groneck-Kaufmann; A7 (non-OECD) is the novel sample
        (oecd, lhs_level, demo_vars + controls_bs, "A6: OECD Z + BS"),
        (non_oecd, lhs_level, demo_vars + controls_bs, "A7: non-OECD Z + BS"),
    ])
    build_table(level_results,
                ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                 'log_gdp_pc', 'health_exp_gdp', 'trade_openness'],
                f"Base controls (all models): {basic_str}. "
                "A2-A7 add log_gdp_pc (Balassa-Samuelson); A3 adds trade_openness; "
                "A4 adds health_exp_gdp; A5 uses old_dep, youth_dep in place of Z_1-Z_3.",
                "baseline_reer_level.md",
                "Baseline: Demographics → log(REER) — Level")

    # ═══════════════════════════════════════════════════════════════════
    # PART B: Change models — Z → Δlog(REER)
    # ═══════════════════════════════════════════════════════════════════
    _banner("PART B: CHANGE MODELS — Z → Δlog(REER)")
    change_results = _run_specs([
        (df, lhs_change, demo_vars + controls_basic, "B1: Z → Δlog REER"),
        (df, lhs_change, demo_vars + controls_bs, "B2: Z → Δlog REER + BS"),
        (df, lhs_change, age_vars + controls_bs, "B3: age ratios → Δlog REER"),
        (oecd, lhs_change, demo_vars + controls_bs, "B4: OECD Z → Δlog REER"),
        (non_oecd, lhs_change, demo_vars + controls_bs, "B5: non-OECD Z → Δlog REER"),
    ])
    build_table(change_results,
                ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep', 'log_gdp_pc'],
                f"Base controls (all models): {basic_str}. "
                "B2-B5 add log_gdp_pc; B3 uses old_dep, youth_dep in place of Z_1-Z_3.",
                "baseline_reer_change.md",
                "Baseline: Demographics → Δlog(REER) — Change")

    # ═══════════════════════════════════════════════════════════════════
    # PART C: Income terciles & Eurozone
    # ═══════════════════════════════════════════════════════════════════
    _banner("PART C: INCOME TERCILES & EUROZONE")
    subsample_specs = [
        (_subset(df, 'income_group', group), lhs_level,
         demo_vars + controls_bs, f"C{idx}: {group}-income")
        for idx, group in enumerate(['low', 'mid', 'high'], start=1)
    ]
    subsample_specs += [
        (_subset(df, 'eurozone', 1), lhs_level, demo_vars + controls_bs, "C4: Eurozone"),
        (_subset(df, 'eurozone', 0), lhs_level, demo_vars + controls_bs, "C5: Non-eurozone"),
    ]
    subsample_results = _run_specs(subsample_specs)
    build_table(subsample_results,
                ['Z_1', 'Z_2', 'Z_3', 'log_gdp_pc'],
                f"Controls: {bs_str}",
                "reer_subsamples.md",
                "Demographics → log(REER) — Income & Currency Subsamples")

    # ═══════════════════════════════════════════════════════════════════
    # PART D: PPP deviation (demeaned REER)
    # ═══════════════════════════════════════════════════════════════════
    _banner("PART D: PPP DEVIATION (DEMEANED REER)")
    ppd_results = _run_specs([
        (df, lhs_ppd, demo_vars + controls_basic, "D1: Z → PPP deviation"),
        (df, lhs_ppd, demo_vars + controls_bs, "D2: Z → PPP deviation + BS"),
        (oecd, lhs_ppd, demo_vars + controls_bs, "D3: OECD PPP deviation"),
        (non_oecd, lhs_ppd, demo_vars + controls_bs, "D4: non-OECD PPP deviation"),
    ])
    build_table(ppd_results,
                ['Z_1', 'Z_2', 'Z_3', 'log_gdp_pc'],
                "PPP deviation = log(REER) - country mean. "
                f"Base controls (all models): {basic_str}; D2-D4 add log_gdp_pc.",
                "reer_ppd.md",
                "Demographics → PPP Deviation")

    # ═══════════════════════════════════════════════════════════════════
    # PART E: Lagged & first-differenced demographics
    # ═══════════════════════════════════════════════════════════════════
    _banner("PART E: LAGGED & FIRST-DIFFERENCED DEMOGRAPHICS")
    lag_results = _run_specs([
        (df, lhs_level, lag_vars + controls_bs, "E1: Z_lag5 → log REER"),
        (df, lhs_level, diff_vars + controls_bs, "E2: ΔZ → log REER"),
        # E3: predetermined old-age dependency ratio (OADR+20)
        (df, lhs_level, ['oadr_plus20'] + controls_bs, "E3: OADR+20 → log REER"),
        (df, lhs_change, lag_vars + controls_bs, "E4: Z_lag5 → Δlog REER"),
    ])
    build_table(lag_results,
                lag_vars + diff_vars + ['oadr_plus20'],
                f"Lagged and differenced demographics. Controls (all models): {bs_str}.",
                "reer_lagged.md",
                "Lagged & Differenced Demographics → REER")

    _banner("Phase 2 complete.")


# Run the Phase 2 pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
